import json
import time
from pathlib import Path

import numpy as np
from PIL import Image
from tqdm import tqdm

from cache_image_editing import CIE
from utils_loc import load_config, init_model, seed_everything, mask_decode
from utils_loc.layer_utils import custom_module_register

def _strip_brackets(prompt):
    """Return ``prompt`` with every ``[`` and ``]`` character removed.

    The bracket characters are stripped before prompts are handed to the
    image editor.
    """
    return prompt.replace("[", "").replace("]", "")


def _iter_batches(records, batch_size):
    """Yield consecutive lists of at most ``batch_size`` items from ``records``.

    The final list may be shorter than ``batch_size``; nothing is yielded for
    an empty input. A ``batch_size`` below 1 degrades to one item per batch.
    """
    batch = []
    for record in records:
        batch.append(record)
        if len(batch) >= batch_size:
            yield batch
            batch = []
    # Flush the trailing partial batch, if any.
    if batch:
        yield batch


def main():
    """Load the model and run CIE over the selected entries of the instruction file.

    Entries whose ``editing_type_id`` is not in ``config.edit_category_list``
    are skipped. The remaining entries are edited in batches of
    ``config.batch_size``. The total wall-clock time (including model loading)
    is printed at the end.
    """
    start_time = time.perf_counter()
    config = load_config()
    pipe, model_key = init_model(device=config.device,
                                 sd_version=config.sd_version,
                                 model_key=config.model_key,
                                 model_path=config.model_path,
                                 weight_dtype=config.float_precision)

    config.model_key = model_key
    seed_everything(config.seed)

    custom_module_register(pipe, config)

    image_editor = CIE(pipe, config)
    edit_category_list = config.edit_category_list

    # Read the whole file up front so it is closed before the long-running
    # editing loop starts.
    with open(Path(config.input_path), "r", encoding="utf-8") as f:
        instructions = json.load(f)

    # Entries may also carry "mask", "blended_word" and "editing_instruction";
    # the editor call below only takes prompts and image paths, so those
    # fields are not decoded here.
    records = (
        (_strip_brackets(item["original_prompt"]),
         _strip_brackets(item["editing_prompt"]),
         item["image_path"])
        for item in tqdm(instructions.values())
        if item["editing_type_id"] in edit_category_list
    )

    for batch in _iter_batches(records, config.batch_size):
        # Transpose the (original, editing, path) tuples into three parallel lists.
        original_prompts, editing_prompts, image_paths = map(list, zip(*batch))
        image_editor.set_prompts_and_paths(original_prompts, editing_prompts, image_paths)
        image_editor()

    time_cost = time.perf_counter() - start_time
    print(f"Total time cost: {time_cost:.2f}s")


if __name__ == "__main__":
    main()
